# Main imports


from econml.dml import CausalForestDML

# Helper imports
import numpy as np

from sklearn.linear_model import Lasso, LogisticRegression
from causally.model.abstract_model import SKAbstractModel


class CF(SKAbstractModel):
    """Causal Forest model backed by econml's ``CausalForestDML``.

    Hyperparameters are read from ``config``. The nuisance models are
    L1-regularised, with strength ``lambda_reg``.
    """

    # Feature count plugged into the default ``lambda_reg`` heuristic
    # sqrt(log(p) / (10 * subsample_ratio * n)).
    # NOTE(review): this is a fixed constant, not derived from ``dataset``;
    # confirm whether it should be the dataset's actual feature count.
    _LAMBDA_HEURISTIC_N_FEATURES = 30

    def __init__(self, config, dataset):
        """Read hyperparameters from ``config`` and build the estimator.

        Args:
            config: indexable configuration providing ``n_estimators``,
                ``min_samples_leaf``, ``max_depth``, ``discrete_treatment``,
                ``subsample_ratio``, ``random_state`` and ``lambda_reg``.
                When ``lambda_reg`` is ``None`` a default is derived from
                ``subsample_ratio`` and the number of units.
            dataset: object exposing ``get_X_size()``; its first entry is
                used as the number of units.
        """
        super().__init__(config, dataset)

        self.n_units = dataset.get_X_size()[0]
        self.n_estimators = config['n_estimators']
        self.min_samples_leaf = config['min_samples_leaf']
        self.max_depth = config['max_depth']
        self.discrete_treatment = config['discrete_treatment']
        self.subsample_ratio = config['subsample_ratio']
        self.random_state = config['random_state']

        # An explicitly configured regularisation strength takes precedence;
        # the heuristic is only a fallback when no value is supplied.
        lambda_reg = config['lambda_reg']
        if lambda_reg is None:
            lambda_reg = np.sqrt(
                np.log(self._LAMBDA_HEURISTIC_N_FEATURES)
                / (10 * self.subsample_ratio * self.n_units)
            )
        self.lambda_reg = lambda_reg

        if self.discrete_treatment:
            # Discrete treatments require a classifier for the treatment model.
            model_t = LogisticRegression(C=1 / (self.n_units * self.lambda_reg))
        else:
            # A classifier cannot be fit on a continuous treatment; use the
            # same L1-regularised regressor as the outcome model.
            model_t = Lasso(alpha=self.lambda_reg)

        self.model = CausalForestDML(
            model_y=Lasso(alpha=self.lambda_reg),
            model_t=model_t,
            n_estimators=self.n_estimators,
            min_samples_leaf=self.min_samples_leaf,
            max_depth=self.max_depth,
            # Per-tree subsample fraction: half of the configured ratio.
            max_samples=self.subsample_ratio / 2,
            discrete_treatment=self.discrete_treatment,
            random_state=self.random_state,
        )

    def calculate_loss(self, x, t, y, w):
        """Fit the causal forest on covariates ``x``, treatment ``t``, outcome ``y``.

        ``w`` is accepted for interface compatibility but is not used.
        NOTE(review): if ``w`` holds control variables it could be forwarded
        as ``W=w`` to ``fit``; confirm its meaning with the caller.

        Returns:
            None. The fitted state is kept on ``self.model``.
        """
        self.model.fit(y, t, X=x)

    def predict(self, x, t_0, t_1):
        """Return the estimated effect at ``x`` of moving treatment from ``t_0`` to ``t_1``."""
        return self.model.effect(X=x, T0=t_0, T1=t_1)